import copy
import csv
import json
import random
from collections import defaultdict

import torch
from datasets import DatasetDict, concatenate_datasets, load_dataset
from transformers import DataCollatorForLanguageModeling, default_data_collator

from .Base import BaseDataset, UnlearnDataset


class PII(BaseDataset):
    """Email-PII dataset built from two local JSONL files.

    Two sources are loaded at construction time and stored in
    ``self.dataset``:

    * ``"context"``    -- rows read via their ``prompt`` and ``email`` columns.
    * ``"email2name"`` -- rows read via their ``name`` and ``email`` columns;
      prompts for these rows are generated from a JSON prompt template.

    ``build_dataset`` produces fixed-length tokenized tensors (with both the
    real e-mail and an anonymized "refusal" e-mail as targets), while
    ``build_pretrain_dataset`` produces plain ``text`` samples.

    The attributes ``if_llama``, ``question_start_token``,
    ``question_end_token`` and ``answer_start_token`` are read from ``self``
    and are expected to be provided by ``BaseDataset``.
    """

    # Locations of the raw data; class-level so subclasses can override them.
    CONTEXT_FILE = "files/data/PII/context.jsonl"
    EMAIL2NAME_FILE = "files/data/PII/email2name.jsonl"
    PROMPT_TEMPLATE_FILE = "files/data/PII/prompt_template.json"
    CACHE_DIR = "./.cache"

    # Every tokenized sample is truncated / padded to this many tokens.
    MAX_LENGTH = 512
    # Templates whose numeric key exceeds this value are skipped by
    # ``build_dataset`` (presumably held out for evaluation -- TODO confirm).
    MAX_TEMPLATE_KEY = 5

    def __init__(self, dataset_name, with_retain=False, if_llama=False, subset=None):
        """Store the configuration and eagerly load the raw JSONL files.

        Args:
            dataset_name: Forwarded to ``BaseDataset.__init__``.
            with_retain: Forwarded to ``BaseDataset.__init__``.
            if_llama: Forwarded to ``BaseDataset.__init__``; when true,
                prompts/answers are wrapped in the chat marker tokens.
            subset: Stored on the instance; not used by this class itself.
        """
        super().__init__(dataset_name, with_retain, if_llama)
        self.subset = subset
        # A defaultdict without a factory behaves like a plain dict; the
        # container type is kept because it is returned by build_dataset.
        self.dataset = defaultdict()
        self.get_dataset()

    def get_dataset(self):
        """Load both JSONL sources into ``self.dataset``."""
        sources = (
            ("context", self.CONTEXT_FILE),
            ("email2name", self.EMAIL2NAME_FILE),
        )
        for key, path in sources:
            self.dataset[key] = load_dataset(
                "json",
                data_files=path,
                split="train",
                cache_dir=self.CACHE_DIR,
            )

    def anonymize_email(self, email):
        """Return ``email`` with every character except separators masked.

        The user part is replaced entirely by ``x``; in the domain part the
        ``.`` separators are kept and everything else becomes ``x``. The
        result always has the same length as the input.

        Malformed input is masked rather than raising: a string without
        ``@`` is masked completely, and any additional ``@`` after the first
        one is treated as an ordinary domain character (masked).
        """
        user, at, domain = email.partition("@")
        if not at:
            # No '@' at all: nothing can be identified as a domain.
            return "x" * len(email)
        masked_domain = ".".join("x" * len(part) for part in domain.split("."))
        return f"{'x' * len(user)}@{masked_domain}"

    def _load_prompt_template(self):
        """Read the prompt-template JSON (a mapping of key -> format string)."""
        with open(self.PROMPT_TEMPLATE_FILE, encoding="utf-8") as f:
            return json.load(f)

    def _wrap_prompt(self, prompt):
        """Surround ``prompt`` with the question markers when ``if_llama``."""
        if self.if_llama:
            return self.question_start_token + prompt + self.question_end_token
        return prompt

    def _wrap_answer(self, answer):
        """Prefix ``answer`` with the answer marker when ``if_llama``."""
        if self.if_llama:
            return f"{self.answer_start_token} {answer}"
        return f"{answer}"

    def __preprocess__(self, tokenizer):
        """Tokenize all samples and store them as ``self.dataset["train"]``.

        Each resulting row has ``input_ids``, ``attention_mask``, ``label``,
        ``refused_label`` (all of length ``MAX_LENGTH``) and
        ``question_length``. ``self.dataset["test"]`` is set to ``None``.

        Args:
            tokenizer: A HuggingFace-style tokenizer (callable, with
                ``tokenize``, ``pad_token_id`` and ``eos_token_id``).
        """
        max_length = self.MAX_LENGTH
        templates = [
            template
            for key, template in self._load_prompt_template().items()
            if int(key) <= self.MAX_TEMPLATE_KEY
        ]

        def preprocess_PII(examples):
            # One output row per (person, template) pair.
            results = {"name": [], "prompt": [], "email": []}
            for name, email in zip(examples["name"], examples["email"]):
                for template in templates:
                    results["prompt"].append(template.format(name))
                    results["name"].append(name)
                    results["email"].append(email)
            return results

        def filter_long_text(examples):
            # Drop rows whose (wrapped) prompt alone would fill the window.
            results = {"prompt": [], "email": []}
            for prompt, email in zip(examples["prompt"], examples["email"]):
                num_question_token = len(
                    tokenizer.tokenize(
                        self._wrap_prompt(prompt),
                        add_special_tokens=True,
                    )
                )
                if num_question_token < max_length:
                    # Keep the RAW prompt: `preprocess` below applies the
                    # chat wrapping itself, so storing the wrapped text here
                    # would wrap these rows twice.
                    results["prompt"].append(prompt)
                    results["email"].append(email)
            return results

        email2name_data = self.dataset["email2name"].map(
            preprocess_PII, batched=True, remove_columns=["name"]
        )
        context_data = self.dataset["context"].map(filter_long_text, batched=True)
        train_data = concatenate_datasets([email2name_data, context_data])

        # Some tokenizers ship without a pad token; padded positions are
        # masked out by attention_mask, so EOS is a safe stand-in.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        def preprocess(examples):
            results = {
                "input_ids": [],
                "attention_mask": [],
                "label": [],
                "refused_label": [],
                "question_length": [],
            }
            for prompt, email in zip(examples["prompt"], examples["email"]):
                question = self._wrap_prompt(prompt)
                text = question + self._wrap_answer(email)
                refusal_text = question + self._wrap_answer(
                    self.anonymize_email(email)
                )

                tokenized = tokenizer(
                    text,
                    truncation=True,
                    add_special_tokens=True,
                    max_length=max_length,
                )
                input_ids = list(tokenized.input_ids)
                attention_mask = list(tokenized.attention_mask)
                num_question_token = len(
                    tokenizer.tokenize(
                        question,
                        add_special_tokens=True,
                    )
                )
                # Never mask past the end of the fixed-size label vectors.
                num_masked = min(num_question_token, max_length)

                pad_length = max_length - len(input_ids)
                if pad_length > 0:
                    # EOS becomes the target at the first padded position;
                    # the remaining padding is ignored by the loss (-100).
                    label = (
                        input_ids
                        + [tokenizer.eos_token_id]
                        + [-100] * (pad_length - 1)
                    )
                else:
                    label = list(input_ids)
                # Question tokens do not contribute to the loss.
                label[:num_masked] = [-100] * num_masked

                # Truncate to the same window as `text` and measure the
                # token list (not the BatchEncoding, whose len() counts its
                # keys) so refused_label is always exactly max_length long.
                refusal_ids = list(
                    tokenizer(
                        refusal_text,
                        truncation=True,
                        padding=False,
                        add_special_tokens=True,
                        max_length=max_length,
                    ).input_ids
                )
                refusal_label = refusal_ids + [-100] * (
                    max_length - len(refusal_ids)
                )
                refusal_label[:num_masked] = [-100] * num_masked

                results["input_ids"].append(
                    torch.tensor(input_ids + [pad_token_id] * pad_length)
                )
                results["attention_mask"].append(
                    torch.tensor(attention_mask + [0] * pad_length)
                )
                results["label"].append(torch.tensor(label))
                results["question_length"].append(torch.tensor(num_question_token))
                results["refused_label"].append(torch.tensor(refusal_label))
            return results

        train_dataset = train_data.map(
            preprocess, batched=True, remove_columns=["prompt", "email"]
        )

        train_dataset.set_format(
            type="torch",
            columns=[
                "input_ids",
                "attention_mask",
                "label",
                "refused_label",
                "question_length",
            ],
        )

        self.dataset["train"] = train_dataset
        self.dataset["test"] = None

    def build_dataset(self, tokenizer):
        """Run the tokenizing pipeline and return ``self.dataset``.

        The returned mapping contains the raw ``"context"`` / ``"email2name"``
        datasets plus the tokenized ``"train"`` split and ``"test"`` (None).
        """
        self.__preprocess__(tokenizer)
        return self.dataset

    def build_pretrain_dataset(self, tokenizer):
        """Build plain-text samples (prompt followed by the real e-mail).

        For every ``email2name`` row one prompt template is drawn at random
        (via the global ``random`` state); ``context`` rows are used as they
        are, without length filtering. ``tokenizer`` is accepted for
        interface compatibility but is not used.

        Returns:
            A ``DatasetDict`` whose ``"train"`` and ``"test"`` entries are the
            same dataset with a single ``text`` column.
        """
        prompt_list = list(self._load_prompt_template().values())

        def preprocess_PII(examples):
            results = {"name": [], "prompt": [], "email": []}
            for name, email in zip(examples["name"], examples["email"]):
                template = random.choice(prompt_list)
                results["prompt"].append(template.format(name))
                results["name"].append(name)
                results["email"].append(email)
            return results

        email2name_data = self.dataset["email2name"].map(
            preprocess_PII, batched=True, remove_columns=["name"]
        )
        train_data = concatenate_datasets([email2name_data, self.dataset["context"]])

        def preprocess(examples):
            return {
                "text": [
                    self._wrap_prompt(prompt) + self._wrap_answer(email)
                    for prompt, email in zip(examples["prompt"], examples["email"])
                ]
            }

        train_dataset = train_data.map(
            preprocess, batched=True, remove_columns=train_data.column_names
        )

        return DatasetDict({"train": train_dataset, "test": train_dataset})
